This document assesses the ability of several models to classify StraightenedWorm data. Manual annotations are used to train, validate, and select the best model and hyperparameters.

First load the data:

# Load modelling and data-wrangling packages
library(tidymodels)
library(tidyverse)
library(here)

# Locate every manually annotated CSV anywhere under the project root.
# Escape the dot and anchor the end of the name so that files such as
# 'x_manualXcsv' or 'x_manual.csv.bak' cannot match (the original
# pattern '.*_manual.csv' treated '.' as "any character" and was
# unanchored).
files <- tibble(path = list.files(path = here(),
                                  pattern = '_manual\\.csv$',
                                  recursive = TRUE))

#' Read one annotation CSV given a single row of `files`.
#'
#' `pmap_dfr()` passes each row's columns as named arguments, captured
#' here with `...`; `df$path` is then the file's path relative to the
#' project root.
#'
#' @param ... Named column values for one row of `files` (must include
#'   `path`).
#' @return A tibble with the parsed CSV contents.
get_data <- function(...) {
  df <- tibble(...)
  # Return the parsed CSV directly: the original dangling
  # `data <- read_csv(...)` assignment only obscured that this value is
  # the function's result (assignments return invisibly).
  read_csv(here(df$path))
}

# Combine every annotation file, tidy the column names, and keep only
# the rows that received a manual label.  The annotator's single-letter
# codes are expanded into readable class labels and converted to a
# factor in a single mutate step.
annotations <- files %>%
  pmap_dfr(get_data) %>%
  janitor::clean_names() %>%
  filter(!is.na(worm)) %>%
  select(worm, contains('area')) %>%
  mutate(worm = factor(case_when(
    worm == 'Y' ~ 'Single worm',
    worm == 'N' ~ 'Debris',
    worm == 'P' ~ 'Partial worm',
    worm == 'M' ~ 'Multiple worms'
  )))

# Quick structural check of the assembled data
glimpse(annotations)
## Rows: 4,821
## Columns: 25
## $ worm                              <fct> Debris, Debris, Single worm, Single …
## $ area_shape_area                   <dbl> 5205, 2031, 2469, 2030, 2497, 2152, …
## $ area_shape_bounding_box_area      <dbl> 6069, 2394, 2919, 2394, 2919, 2541, …
## $ area_shape_bounding_box_maximum_x <dbl> 21, 42, 63, 84, 105, 126, 147, 168, …
## $ area_shape_bounding_box_maximum_y <dbl> 296, 121, 146, 121, 146, 128, 244, 1…
## $ area_shape_bounding_box_minimum_x <dbl> 0, 21, 42, 63, 84, 105, 126, 147, 16…
## $ area_shape_bounding_box_minimum_y <dbl> 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, …
## $ area_shape_center_x               <dbl> 10.13698, 31.13639, 51.49494, 72.997…
## $ area_shape_center_y               <dbl> 150.06724, 63.26686, 75.94208, 63.09…
## $ area_shape_compactness            <dbl> 6.075054, 2.603662, 3.150583, 2.5034…
## $ area_shape_eccentricity           <dbl> 0.9973402, 0.9832310, 0.9884558, 0.9…
## $ area_shape_equivalent_diameter    <dbl> 81.40769, 50.85223, 56.06807, 50.839…
## $ area_shape_euler_number           <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, …
## $ area_shape_extent                 <dbl> 0.8576372, 0.8483709, 0.8458376, 0.8…
## $ area_shape_form_factor            <dbl> 0.1646076, 0.3840745, 0.3174016, 0.3…
## $ area_shape_major_axis_length      <dbl> 302.82179, 119.52235, 144.60349, 117…
## $ area_shape_max_feret_diameter     <dbl> 288.00174, 113.03982, 138.00362, 113…
## $ area_shape_maximum_radius         <dbl> 11.00000, 11.00000, 10.19804, 11.000…
## $ area_shape_mean_radius            <dbl> 5.031499, 4.889537, 4.976384, 5.0222…
## $ area_shape_median_radius          <dbl> 5.000000, 5.000000, 5.000000, 5.0000…
## $ area_shape_min_feret_diameter     <dbl> 20, 20, 20, 20, 20, 20, 20, 20, 20, …
## $ area_shape_minor_axis_length      <dbl> 22.07184, 21.79666, 21.90884, 22.154…
## $ area_shape_orientation            <dbl> -0.066591089, -0.161518210, 0.040796…
## $ area_shape_perimeter              <dbl> 630.3625, 257.7817, 312.6518, 252.71…
## $ area_shape_solidity               <dbl> 0.9553965, 0.9562147, 0.9614486, 0.9…

Explore data

The Worm Toolbox in Cell Profiler can export a variety of features, some of which may be useful in classification.

library(ggbeeswarm)

# Beeswarm plot of every measurement, split by class, to eyeball which
# features separate the four categories.
annotations %>%
  pivot_longer(-worm, names_to = 'measurement', values_to = 'value') %>%
  ggplot(aes(x = worm, y = value, color = worm)) +
  geom_quasirandom() +
  facet_wrap(facets = vars(measurement), scales = 'free_y') +
  theme_minimal()

Build models

First split the data into training and test sets, then set up 10-fold cross-validation on the training set

# `worm` is already a factor (see the annotations pipeline); coercing
# again is harmless and keeps this chunk self-contained.
model_data <- mutate(annotations, worm = factor(worm))

# Stratified hold-out split (default 3/4 training) so the test set
# keeps the class balance.
set.seed(123)
# data_boot <- bootstraps(model_data, times = 2) # only 2 bootstraps for testing
data_split <- initial_split(model_data,
                            strata = worm)
train_data <- training(data_split)
test_data <- testing(data_split)

# 10-fold cross-validation on the training data, again stratified by
# class.
set.seed(234)
folds <- vfold_cv(train_data,
                  v = 10,
                  strata = worm)

Specify the models (decision tree, multinomial regression, random forest, polynomial SVM, and XGBoost):

# Decision tree (rpart): tune complexity, depth, and minimum node size.
decision_tree_rpart_spec <-
  decision_tree(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
  set_engine('rpart') %>%
  set_mode('classification')

# Multinomial regression (glmnet): tune the elastic-net penalty and
# mixture.  No set_mode() call is needed: classification is the only
# mode multinom_reg() supports.
multinom_reg_glmnet_spec <-
  multinom_reg(penalty = tune(), mixture = tune()) %>%
  set_engine('glmnet')

# Random forest (ranger), parallelised across all available cores via
# ranger's own num.threads argument.
cores <- parallel::detectCores()
rand_forest_ranger_spec <-
  rand_forest(mtry = tune(), min_n = tune()) %>%
  set_engine('ranger', num.threads = cores) %>%
  set_mode('classification')

# Polynomial-kernel SVM (kernlab): tune cost, degree, scale factor, and
# the no-penalty margin.
svm_poly_kernlab_spec <-
  svm_poly(cost = tune(), degree = tune(), scale_factor = tune(), margin = tune()) %>%
  set_engine('kernlab') %>%
  set_mode('classification')

# Gradient-boosted trees (xgboost): seven hyperparameters tuned; the
# number of trees is left at its default here.
boost_tree_xgboost_spec <-
  boost_tree(tree_depth = tune(), learn_rate = tune(), 
             min_n = tune(), loss_reduction = tune(), mtry = tune(),
             sample_size = tune(), stop_iter = tune()) %>%
  set_engine('xgboost') %>%
  set_mode('classification')

Build the recipe and workflow:

library(themis)

# Preprocessing: drop near-zero-variance predictors, centre and scale,
# remove highly correlated predictors (|r| > 0.5), then oversample the
# minority classes with SMOTE to balance the outcome.
# NOTE(review): the recipe is defined (and prep()ed just below) on the
# FULL data set rather than train_data.  Inside tune_grid() each fold
# is re-prepped on its analysis set, so tuning itself should not leak,
# but this prep/juice preview includes test rows -- confirm intended.
# NOTE(review): the names `recipe`, `prep`, and `juice` shadow the
# recipes functions of the same names; R still resolves the function
# calls correctly, but renaming would be clearer.
recipe <-
  recipe(worm ~ ., data = model_data) %>%
  step_nzv(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = .5) %>%
  step_smote(worm)

# Prep and juice only to inspect what the trained recipe produces.
prep <- prep(recipe)
juice <- juice(prep)

prep
## Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor         24
## 
## Training data contained 4821 data points and no missing data.
## 
## Operations:
## 
## Sparse, unbalanced variable filter removed area_shape_euler_number, area_shape_max... [trained]
## Centering and scaling for area_shape_area, area_shape_bounding_box_area, ... [trained]
## Correlation filter removed area_shape_bounding_box_maximum_y, area_s... [trained]
## SMOTE based on worm [trained]
glimpse(juice)
## Rows: 13,132
## Columns: 6
## $ area_shape_bounding_box_minimum_x <dbl> -1.2442535, -1.1710140, -1.0977744, …
## $ area_shape_bounding_box_minimum_y <dbl> 0.5178302, 0.5178302, 0.5178302, 0.5…
## $ area_shape_eccentricity           <dbl> 0.30806485, 0.13513151, 0.19917063, …
## $ area_shape_minor_axis_length      <dbl> -0.7834696, -1.1095671, -0.9766285, …
## $ area_shape_orientation            <dbl> -0.076386694, -0.113077317, -0.03487…
## $ worm                              <fct> Debris, Debris, Single worm, Single …
# Variant recipe for the tree-based models: mark the correlation filter
# (step 3) as skip = TRUE.  skip = TRUE means the filter IS applied to
# the training data during prep() but is NOT applied when bake()
# processes new data.
# NOTE(review): confirm this asymmetry is intended -- skipped filter
# steps can leave baked assessment data with more columns than the
# trained workflow saw.
recipe2 <- recipe
recipe2$steps[[3]] <- update(recipe2$steps[[3]], skip = TRUE)

# Decision tree uses the skip-corr recipe (trees tolerate correlated
# predictors).
dt_workflow <-
  workflow() %>%
  add_model(decision_tree_rpart_spec) %>%
  add_recipe(recipe2)

# Multinomial regression keeps the correlation filter.
mn_workflow <-
  workflow() %>%
  add_model(multinom_reg_glmnet_spec) %>%
  add_recipe(recipe)

# Random forest also uses the skip-corr recipe.
rf_workflow <-
  workflow() %>%
  add_model(rand_forest_ranger_spec) %>%
  add_recipe(recipe2)

# SVM keeps the correlation filter.
svm_poly_workflow <-
  workflow() %>%
  add_model(svm_poly_kernlab_spec) %>%
  add_recipe(recipe)

# XGBoost keeps the correlation filter.
xg_workflow <-
  workflow() %>% 
  add_model(boost_tree_xgboost_spec) %>% 
  add_recipe(recipe)

Tune the models

Decision tree

# Regular grid over the three rpart hyperparameters, 5 levels each
# (125 candidates).
dt_grid <- grid_regular(cost_complexity(),
                        tree_depth(),
                        min_n(),
                        levels = 5)

# tune across the cross-validation folds of the training data
dt_tune <-
  dt_workflow %>%
  tune_grid(
    resamples = folds,
    grid = dt_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens)
  )

# cache the (slow) tuning results so the chunk can be re-run cheaply
write_rds(dt_tune, here('code', 'rds', 'dt_tune.rds'))

dt_tune <- read_rds(here('code', 'rds', 'dt_tune.rds'))

# extract the best decision tree; the metric argument is named because
# the positional form is deprecated in recent {tune} releases
best_tree <- dt_tune %>%
  select_best(metric = "roc_auc")

# print metrics for the winning configuration (semi_join infers the
# join columns from best_tree and emits a message)
(dt_metrics <- dt_tune %>% 
    collect_metrics() %>% 
    semi_join(best_tree) %>% 
    select(.metric:.config) %>% 
    mutate(model = 'Decision tree'))
## # A tibble: 2 × 7
##   .metric .estimator  mean     n std_err .config                model        
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>                  <chr>        
## 1 roc_auc hand_till  0.813    10  0.0113 Preprocessor1_Model114 Decision tree
## 2 sens    macro      0.586    10  0.0168 Preprocessor1_Model114 Decision tree
# finalize the wf with the best tree
dt_workflow <-
  dt_workflow %>%
  finalize_workflow(best_tree)

# ROC curves from the out-of-fold resampling predictions (NOT the
# hold-out test set, despite collect_predictions' name)
dt_auc <-
  dt_tune %>%
  collect_predictions(parameters = best_tree) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "Decision tree")

dt_auc %>%
  autoplot()

Multinomial regression

# Regular grid over the elastic-net hyperparameters (default 3 levels
# each).
mn_grid <- grid_regular(mixture(),
                        penalty())

# tune across the cross-validation folds
mn_tune <-
  mn_workflow %>%
  tune_grid(
    resamples = folds,
    grid = mn_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens))

# cache the tuning results
write_rds(mn_tune, here('code', 'rds', 'mn_tune.rds'))

mn_tune <- read_rds(here('code', 'rds', 'mn_tune.rds'))

# extract the best model; metric named because the positional form is
# deprecated in recent {tune} releases
best_mn <- mn_tune %>%
  select_best(metric = "roc_auc")

# print metrics for the winning configuration
(mn_metrics <- mn_tune %>% 
    collect_metrics() %>% 
    semi_join(best_mn) %>% 
    select(.metric:.config) %>% 
    mutate(model = 'Multinomial regression'))
## # A tibble: 2 × 7
##   .metric .estimator  mean     n std_err .config              model             
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>                <chr>             
## 1 roc_auc hand_till  0.790    10 0.00847 Preprocessor1_Model7 Multinomial regre…
## 2 sens    macro      0.537    10 0.0174  Preprocessor1_Model7 Multinomial regre…
# finalize the wf with the best model
mn_workflow <-
  mn_workflow %>%
  finalize_workflow(best_mn)

# ROC curves from the out-of-fold resampling predictions (NOT the
# hold-out test set)
mn_auc <-
  mn_tune %>%
  collect_predictions(parameters = best_mn) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "Multinomial regression")

mn_auc %>%
  autoplot()

Random forest

# mtry's upper bound depends on the data, so finalize() it first.
# NOTE(review): model_data still contains the outcome column (and
# predictors the recipe later removes), so the mtry upper bound is
# inflated -- confirm intended (cf. the xgboost grid, which uses
# train_data).
rf_grid <- grid_regular(finalize(mtry(), model_data),
                        min_n())

# tune across the cross-validation folds
rf_tune <-
  rf_workflow %>%
  tune_grid(
    resamples = folds,
    grid = rf_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens))

# cache the tuning results
write_rds(rf_tune, here('code', 'rds', 'rf_tune.rds'))

rf_tune <- read_rds(here('code', 'rds', 'rf_tune.rds'))

# extract the best model; metric named because the positional form is
# deprecated in recent {tune} releases
best_rf <- rf_tune %>%
  select_best(metric = "roc_auc")

# print metrics for the winning configuration
(rf_metrics <- rf_tune %>% 
    collect_metrics() %>% 
    semi_join(best_rf) %>% 
    select(.metric:.config) %>% 
    mutate(model = 'Random forest'))
## # A tibble: 2 × 7
##   .metric .estimator  mean     n std_err .config              model        
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>                <chr>        
## 1 roc_auc hand_till  0.845    10  0.0110 Preprocessor1_Model7 Random forest
## 2 sens    macro      0.605    10  0.0170 Preprocessor1_Model7 Random forest
# finalize the wf with the best model
rf_workflow <-
  rf_workflow %>%
  finalize_workflow(best_rf)

# ROC curves from the out-of-fold resampling predictions (NOT the
# hold-out test set)
rf_auc <-
  rf_tune %>%
  collect_predictions(parameters = best_rf) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "Random forest")

rf_auc %>%
  autoplot()

SVM

# Regular grid over the four polynomial-SVM hyperparameters (default 3
# levels each).
svm_grid <- grid_regular(cost(),
                         degree(),
                         scale_factor(),
                         svm_margin())

# tune across the cross-validation folds
svm_tune <-
  svm_poly_workflow %>%
  tune_grid(
    resamples = folds,
    grid = svm_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens))

# cache the tuning results
write_rds(svm_tune, here('code', 'rds', 'svm_tune.rds'))

svm_tune <- read_rds(here('code', 'rds', 'svm_tune.rds'))

# extract the best svm; metric named because the positional form is
# deprecated in recent {tune} releases
best_svm <- svm_tune %>%
  select_best(metric = "roc_auc")

# print metrics for the winning configuration
# NOTE(review): n = 4 (not 10) below implies the metrics could only be
# computed on 4 of the 10 folds for this candidate -- worth checking
# the tuning notes for fold-level failures.
(svm_metrics <- svm_tune %>% 
    collect_metrics() %>% 
    semi_join(best_svm) %>% 
    select(.metric:.config) %>% 
    mutate(model = 'SVM'))
## # A tibble: 2 × 7
##   .metric .estimator  mean     n std_err .config               model
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 <chr>
## 1 roc_auc hand_till  0.805     4 0.00611 Preprocessor1_Model27 SVM  
## 2 sens    macro      0.615     4 0.0166  Preprocessor1_Model27 SVM
# finalize the wf with the best svm
svm_workflow <-
  svm_poly_workflow %>%
  finalize_workflow(best_svm)

# ROC curves from the out-of-fold resampling predictions (NOT the
# hold-out test set)
svm_auc <-
  svm_tune %>%
  collect_predictions(parameters = best_svm) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "SVM")

svm_auc %>%
  autoplot()

XGBoost

# Latin-hypercube sample of 30 candidates across the seven tuned
# xgboost hyperparameters; sample_size must be expressed as a
# proportion, and mtry's upper bound is finalized on the training data.
xg_grid <- grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), train_data),
  learn_rate(),
  stop_iter(),
  size = 30
)

# tune across the cross-validation folds of the training data
xg_tune <-
  xg_workflow %>%
  tune_grid(
    resamples = folds,
    grid = xg_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens)
  )

# cache the (slow) tuning results
write_rds(xg_tune, here('code', 'rds', 'xg_tune.rds'))

xg_tune <- read_rds(here('code', 'rds', 'xg_tune.rds'))

# extract the best candidate; metric named because the positional form
# is deprecated in recent {tune} releases
best_xg <- xg_tune %>%
  select_best(metric = "roc_auc")

# print metrics for the winning configuration
(xg_metrics <- xg_tune %>% 
    collect_metrics() %>% 
    semi_join(best_xg) %>% 
    select(.metric:.config) %>% 
    mutate(model = 'XGBoost'))
## # A tibble: 2 × 7
##   .metric .estimator  mean     n std_err .config               model  
##   <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 <chr>  
## 1 roc_auc hand_till  0.841    10 0.00973 Preprocessor1_Model10 XGBoost
## 2 sens    macro      0.601    10 0.0175  Preprocessor1_Model10 XGBoost
# finalize the wf with the best candidate
xg_workflow <-
  xg_workflow %>%
  finalize_workflow(best_xg)

# ROC curves from the out-of-fold resampling predictions (NOT the
# hold-out test set)
xg_auc <-
  xg_tune %>%
  collect_predictions(parameters = best_xg) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "XGBoost")

xg_auc %>%
  autoplot()

Evaluate models

Evaluate using ROC AUC.

# Stack the per-model metric summaries and rank within each metric,
# best (highest mean) first.
(all_metrics <- bind_rows(dt_metrics, mn_metrics, rf_metrics, svm_metrics, xg_metrics) %>% 
   group_by(.metric) %>% 
   arrange(desc(mean)))
## # A tibble: 10 × 7
## # Groups:   .metric [2]
##    .metric .estimator  mean     n std_err .config                model          
##    <chr>   <chr>      <dbl> <int>   <dbl> <chr>                  <chr>          
##  1 roc_auc hand_till  0.845    10 0.0110  Preprocessor1_Model7   Random forest  
##  2 roc_auc hand_till  0.841    10 0.00973 Preprocessor1_Model10  XGBoost        
##  3 roc_auc hand_till  0.813    10 0.0113  Preprocessor1_Model114 Decision tree  
##  4 roc_auc hand_till  0.805     4 0.00611 Preprocessor1_Model27  SVM            
##  5 roc_auc hand_till  0.790    10 0.00847 Preprocessor1_Model7   Multinomial re…
##  6 sens    macro      0.615     4 0.0166  Preprocessor1_Model27  SVM            
##  7 sens    macro      0.605    10 0.0170  Preprocessor1_Model7   Random forest  
##  8 sens    macro      0.601    10 0.0175  Preprocessor1_Model10  XGBoost        
##  9 sens    macro      0.586    10 0.0168  Preprocessor1_Model114 Decision tree  
## 10 sens    macro      0.537    10 0.0174  Preprocessor1_Model7   Multinomial re…
# Overlay every model's ROC curves, one facet per class; the dotted
# diagonal marks chance performance.
(all_models <- bind_rows(dt_auc, mn_auc, rf_auc, svm_auc, xg_auc) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d(option = "plasma") +
  facet_wrap(facets = vars(.level)) +
  theme_minimal())

Random forest and XGBoost consistently perform the best across all 4 classes. Now I fit the final model to the training data using the best parameters and evaluate its performance on the held-out test data.

# Pull the winning xgboost hyperparameters out of the tuning results.
mtry <- best_xg$mtry
trees <- 1000  # not tuned; fixed for the final fit
min_n <- best_xg$min_n
tree_depth <- best_xg$tree_depth
learn_rate <- best_xg$learn_rate
loss_reduction <- best_xg$loss_reduction
sample_size <- best_xg$sample_size  # tuned but previously dropped here
stop_iter <- best_xg$stop_iter      # tuned but previously dropped here

# Final model spec using ALL tuned hyperparameters.
# NOTE: the former `set_engine("xgboost", importance = "impurity")`
# was removed -- `importance` is a {ranger} argument, not an xgboost
# one; vip() computes gain-based importance from the fitted booster
# without it.
last_mod <-
  boost_tree(mtry = mtry,
             trees = trees,
             min_n = min_n,
             tree_depth = tree_depth,
             learn_rate = learn_rate,
             loss_reduction = loss_reduction,
             sample_size = sample_size,
             stop_iter = stop_iter) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Swap the tuned spec for the finalized one in the workflow.
last_workflow <-
  xg_workflow %>%
  update_model(last_mod)

# Fit once on the training split and evaluate on the held-out test set.
# NOTE(review): the object name `last_fit` shadows tune::last_fit();
# R still resolves the function call correctly, but renaming the
# object would be clearer.
set.seed(345)
last_fit <-
  last_workflow %>%
  last_fit(data_split, 
           metrics = metric_set(roc_auc, sens))

# Test-set performance of the final model
collect_metrics(last_fit)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 sens    macro          0.590 Preprocessor1_Model1
## 2 roc_auc hand_till      0.832 Preprocessor1_Model1
# Variable importance from the fitted xgboost booster
last_fit %>%
  extract_fit_engine() %>% 
  vip::vip() +
  theme_minimal()

# Per-class ROC curves on the held-out test predictions
(final_auc <-
  last_fit %>%
  collect_predictions() %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  autoplot())

# Confusion matrix on the held-out test predictions
last_fit %>%
  collect_predictions() %>%
  conf_mat(truth = worm, estimate = .pred_class) %>%
  autoplot()

The model actually performs better on the test data than the training data, indicating that we aren’t overfitting.

In a situation where we probably have more data points than are truly necessary to be able to draw defensible inferences, we are most concerned with accurate identification of a Single Worm. By that I mean that we are ok if the false negative rate is high (i.e., a Single Worm is sometimes identified as Debris, Partial, or Multiple), as long as the objects labelled Single Worm really are single worms. Thus, we want a high true positive and low false positive rate for Single Worms — that is, high positive predictive value (PPV) — even at the expense of sensitivity. Using the selected model and the test set, here’s what would happen if we only kept the StraightenedWorms predicted to be a Single Worm:

# What survives if we keep only objects PREDICTED to be a single worm:
# the one populated row shows how many of each true class slip through.
last_fit %>% 
  collect_predictions() %>% 
  filter(.pred_class == 'Single worm') %>%
  conf_mat(truth = worm, estimate = .pred_class)
##                 Truth
## Prediction       Debris Multiple worms Partial worm Single worm
##   Debris              0              0            0           0
##   Multiple worms      0              0            0           0
##   Partial worm        0              0            0           0
##   Single worm        31             39           38         598
# Total test-set counts per true class, for comparison with the row
# above.
last_fit %>% 
  collect_predictions() %>% 
  group_by(worm) %>% 
  summarise(n())
## # A tibble: 4 × 2
##   worm           `n()`
##   <fct>          <int>
## 1 Debris           155
## 2 Multiple worms    98
## 3 Partial worm     137
## 4 Single worm      817
# Export the fitted workflow so downstream scripts can call augment()
# on new data.
final_wf <- last_fit %>% 
  extract_workflow()

write_rds(final_wf, here('code', 'rds', 'final_workflow.rds'))

# Distribution of major-axis length by true class, before filtering.
# The `data = . %>% ...` argument is a magrittr functional sequence:
# ggplot2 applies it to the layer's inherited data to compute the
# per-class counts printed at y = 550.
pre_filter <- annotations %>% 
  select(worm, area_shape_major_axis_length) %>% 
  ggplot(aes(x = worm, y = area_shape_major_axis_length)) +
  geom_quasirandom(aes(color = worm)) +
  geom_text(data = . %>% group_by(worm) %>% summarise(n = n()),
            aes(label = n), y = 550) +
  theme_minimal() +
  labs(title = 'Pre-filter') +
  lims(y = c(0, 600)) +
  # 'none' is the documented value for suppressing the legend;
  # 'empty' is not a valid legend.position
  theme(legend.position = 'none')
  
# Same plot after keeping only objects the model predicts to be a
# single worm.
post_filter <- augment(final_wf, annotations) %>% 
  filter(.pred_class == 'Single worm') %>% 
  select(worm, area_shape_major_axis_length) %>% 
  ggplot(aes(x = worm, y = area_shape_major_axis_length)) +
  geom_quasirandom(aes(color = worm)) +
  geom_text(data = . %>% group_by(worm) %>% summarise(n = n()),
            aes(label = n), y = 550) +
  theme_minimal() +
  labs(title = 'Post-filter') +
  lims(y = c(0, 600)) +
  theme(legend.position = 'none')

# Side-by-side comparison of the two panels
cowplot::plot_grid(pre_filter, post_filter, nrow = 1, align = 'h', axis = 'tb')

# Per-class attrition: rows before vs after keeping only objects the
# model predicts to be a single worm, expressed as a proportion lost.
# left_join() infers the key (`worm`) and prints a message noting it.
(percent_loss <- annotations %>% 
  select(worm, area_shape_major_axis_length) %>% 
  group_by(worm) %>% 
  summarise(pre_filter = n()) %>%
  left_join(
    augment(final_wf, annotations) %>% 
      filter(.pred_class == 'Single worm') %>% 
      select(worm, area_shape_major_axis_length) %>% 
      group_by(worm) %>% 
      summarise(post_filter = n())
  ) %>% 
  mutate(percent_loss = 1 - post_filter / pre_filter))
## # A tibble: 4 × 4
##   worm           pre_filter post_filter percent_loss
##   <fct>               <int>       <int>        <dbl>
## 1 Debris                589          63        0.893
## 2 Multiple worms        366          71        0.806
## 3 Partial worm          583          94        0.839
## 4 Single worm          3283        2696        0.179